import os
import cv2
from tqdm import tqdm
from PIL import Image
from torch.utils import data
from torchvision import transforms

from .preproc import preproc
from .config import Config
from glob import glob


Image.MAX_IMAGE_PIXELS = None       # remove DecompressionBombWarning
config = Config()
_class_labels_TR_sorted = 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')


class MyData(data.Dataset):
    def __init__(self, data_root, image_size, is_train=True):
        self.size_train = image_size
        self.size_test = image_size
        self.keep_size = not config.size
        self.data_size = (config.size, config.size)
        self.is_train = is_train
        self.load_all = config.load_all
        self.device = config.device
        self.dataset = data_root.replace('\\', '/').split('/')[-1]
        if self.is_train and config.auxiliary_classification:
            self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)}
        self.transform_image = transforms.Compose([
            transforms.Resize(self.data_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ][self.load_all or self.keep_size:])
        self.transform_label = transforms.Compose([
            transforms.Resize(self.data_size),
            transforms.ToTensor(),
        ][self.load_all or self.keep_size:])
        ## 'im' and 'gt' need modifying
        image_root = os.path.join(data_root, 'im')
        self.image_paths = [os.path.join(image_root, p) for p in os.listdir(image_root)]
        self.label_paths = [p.replace('/im/', '/gt/').replace('.jpg', '.png') for p in self.image_paths]
        if self.load_all:
            self.images_loaded, self.labels_loaded = [], []
            self.class_labels_loaded = []
            # for image_path, label_path in zip(self.image_paths, self.label_paths):
            for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)):
                _image = cv2.imread(image_path)
                _label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
                if not self.keep_size:
                    _image_rs = cv2.resize(_image, (config.size, config.size), interpolation=cv2.INTER_LINEAR)
                    _label_rs = cv2.resize(_label, (config.size, config.size), interpolation=cv2.INTER_LINEAR)
                self.images_loaded.append(
                    Image.fromarray(cv2.cvtColor(_image_rs, cv2.COLOR_BGR2RGB)).convert('RGB')
                )
                self.labels_loaded.append(
                    Image.fromarray(_label_rs).convert('L')
                )
                self.class_labels_loaded.append(
                    self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1
                )


    def __getitem__(self, index):

        if self.load_all:
            image = self.images_loaded[index]
            class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1
        else:
            image = Image.open(self.image_paths[index]).convert('RGB')

        # loading image and label
        if self.is_train:
            image, label = preproc(image, image, preproc_methods=config.preproc_methods)
        # else:
        #     if _label.shape[0] > 2048 or _label.shape[1] > 2048:
        #         _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR)
        #         _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR)

        image, label = self.transform_image(image), self.transform_label(label)

        if self.is_train:
            return image, label, class_label
        else:
            return image, label, self.label_paths[index]

    def __len__(self):
        return len(self.image_paths)


class YouData(data.Dataset):
    def __init__(self, data_root, image_size, is_train=True):
        self.size_train = image_size
        self.size_test = image_size
        self.keep_size = not config.size
        self.data_size = (config.size, config.size)
        self.is_train = is_train
        self.load_all = config.load_all
        self.device = config.device
        self.dataset = data_root.replace('\\', '/').split('/')[-1]
        if self.is_train and config.auxiliary_classification:
            self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)}
        self.transform_image = transforms.Compose([
            transforms.Resize(self.data_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ][self.load_all or self.keep_size:])
        ## 'im' and 'gt' need modifying
        self.image_paths = glob(data_root + "/*")
        self.img_sizes = []
        if self.load_all:
            self.images_loaded, self.labels_loaded = [], []
            for image_path in tqdm(self.image_paths, total=len(self.image_paths)):
                _image = cv2.imread(image_path)
                if not self.keep_size:
                    _image_rs = cv2.resize(_image, (config.size, config.size), interpolation=cv2.INTER_LINEAR)
                self.images_loaded.append(
                    Image.fromarray(cv2.cvtColor(_image_rs, cv2.COLOR_BGR2RGB)).convert('RGB')
                )
                self.img_sizes.append(_image.shape[:2])


    def __getitem__(self, index):

        if self.load_all:
            image = self.images_loaded[index]
        else:
            image = Image.open(self.image_paths[index]).convert('RGB')

        # loading image and label
        if self.is_train:
            image, _ = preproc(image, image, preproc_methods=config.preproc_methods)

        image = self.transform_image(image)
        size = self.img_sizes[index]
        return image, size

    def __len__(self):
        return len(self.image_paths)
